/***********************************************************************
 *                               header                                *
 ***********************************************************************/

#include "mex.h"
#include "math.h"
#include <stdio.h>
#include <stdarg.h>
#include <inttypes.h>

/*
 * Two-argument maximum / minimum.
 *
 * These are function-like macros, so each argument can be evaluated twice:
 * do not pass expressions with side effects.
 *
 * Both expand to a single `(a) > (b)` comparison.  When that comparison is
 * false (which includes the case where either operand is NaN), PWMAX yields
 * `b` and PWMIN yields `a`.
 */
#define PWMAX(a, b) ( (a) > (b) ? (a) : (b) )
#define PWMIN(a, b) ( (a) > (b) ? (b) : (a) )

/* Summary returned by squarem() for one fixed-point solve. */
struct squaremResults {
    double norm;    /* last value of the convergence measure computed by squarem() */
    uint64_t iter;  /* iteration counter maintained by squarem()                   */
};

/*
 * Fixed-point solver.  `x` (length K) is the starting point and is updated
 * in place.  `f` is the fixed-point map; it is invoked as f(out, in, args)
 * where `args` is a va_list over the trailing variadic arguments, which are
 * forwarded to `f` untouched.
 */
struct squaremResults squarem(double tol, uint64_t maxiter, double *x, uint64_t K, void f(), ...);

/* Scans the K-element array `x` for the value `v`. */
bool isany(double *x, uint64_t K, double v);

/*
 * Fixed-point map handed to squarem() by contrMap(): writes the updated
 * vector to `pSharesM` given the current vector `expmvaloldM`.  The remaining
 * inputs are pulled from `args` (see the definition for the expected order).
 */
void mktShare(double *pSharesM, double *expmvaloldM, va_list args);

/*
 * Runs squarem() with mktShare() once per market and fills the three output
 * arrays (`expmval`, `mktnorm`, `mktiter`).  Markets are delimited by
 * `cdindex`.
 */
void contrMap(
    double tol,
    double maxiter,
    double *expmvalold,
    double *expmval,
    double *expmu,
    uint64_t nJ,
    uint64_t nI,
    uint64_t nM,
    double *cdindex,
    double *shares,
    double rho,
    double *mktnorm,
    double *mktiter
);

/***********************************************************************
 *                            main function                            *
 ***********************************************************************/

/*
 * MEX gateway.
 *
 * Right-hand side arguments, by position:
 *   prhs[0]  tol         scalar, contraction mapping tolerance
 *   prhs[1]  maxiter     scalar, max iterations in contraction mapping
 *   prhs[2]  expmvalold  initial vector of mean valuations
 *   prhs[3]  expmu       nJ x nI matrix of individual-specific deviations
 *   prhs[4]  cdindex     nM x 1 vector identifying the last obs of each market
 *   prhs[5]  shares      vector of observed market shares
 *   prhs[6]  rho         scalar, nested logit parameter
 *
 * Left-hand side results:
 *   plhs[0]  nJ x 1 updated mean valuations
 *   plhs[1]  nM x 1 per-market convergence measure
 *   plhs[2]  nM x 1 per-market iteration count
 *
 * NOTE(review): nrhs, nlhs and the class/size of each input are not checked
 * here; all seven inputs are read and all three outputs are created
 * unconditionally.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
    /* scalar controls */
    double tol     = mxGetScalar(prhs[0]);
    double maxiter = mxGetScalar(prhs[1]);
    double rho     = mxGetScalar(prhs[6]);

    /* problem dimensions are taken from expmu (rows, columns) and cdindex (rows) */
    uint64_t nJ = mxGetM(prhs[3]);   /* number of product-mkt observations */
    uint64_t nI = mxGetN(prhs[3]);   /* number of individual draws         */
    uint64_t nM = mxGetM(prhs[4]);   /* number of markets                  */

    /* real data of the array inputs */
    double *expmvalold = mxGetPr(prhs[2]);
    double *expmu      = mxGetPr(prhs[3]);
    double *cdindex    = mxGetPr(prhs[4]);
    double *shares     = mxGetPr(prhs[5]);

    /* real data of the outputs, filled in by contrMap() */
    double *expmval;
    double *mktnorm;
    double *mktiter;

    /* allocate the three column-vector outputs */
    plhs[0] = mxCreateDoubleMatrix(nJ, 1, mxREAL);
    plhs[1] = mxCreateDoubleMatrix(nM, 1, mxREAL);
    plhs[2] = mxCreateDoubleMatrix(nM, 1, mxREAL);

    expmval = mxGetPr(plhs[0]);
    mktnorm = mxGetPr(plhs[1]);
    mktiter = mxGetPr(plhs[2]);

    /* contraction mapping via squarem, one market at a time */
    contrMap(tol, maxiter,
             expmvalold, expmval, expmu,
             nJ, nI, nM,
             cdindex, shares, rho,
             mktnorm, mktiter);
}

/*
 * contrMap: solve the mean-valuation fixed point separately for each market.
 *
 * Parameters
 *   tol, maxiter  stopping controls forwarded to squarem() (maxiter is
 *                 converted to uint64_t by the call)
 *   expmvalold    starting values, length nJ; read only
 *   expmval       output, length nJ; receives exp() of each converged value
 *   expmu         nJ x nI column-major matrix of individual deviations
 *   nJ, nI, nM    number of product-market rows, draws, and markets
 *   cdindex       length nM; cdindex[m] is the (1-based) row index of the
 *                 last observation of market m, i.e. the 0-based exclusive
 *                 end of that market's rows
 *   shares        observed shares, length nJ; read only
 *   rho           nested logit parameter; 1 - rho and 1 / (1 - rho) are
 *                 passed through to mktShare()
 *   mktnorm       output, length nM; squarem()'s final convergence measure
 *   mktiter       output, length nM; squarem()'s iteration count
 *
 * squarem() updates its starting vector in place, so each market's starting
 * values are copied into a scratch buffer first.  The caller's `expmvalold`
 * array (the MEX input data) is never written to.
 *
 * NOTE(review): cdindex is trusted to be non-decreasing with
 * cdindex[nM - 1] <= nJ; nothing here verifies that.
 */
void contrMap(
    double tol,
    double maxiter,
    double *expmvalold,
    double *expmval,
    double *expmu,
    uint64_t nJ,
    uint64_t nI,
    uint64_t nM,
    double *cdindex,
    double *shares,
    double rho,
    double *mktnorm,
    double *mktiter)
{
    struct squaremResults results;

    uint64_t j;
    uint64_t i;
    uint64_t mkt;
    uint64_t nJM;        /* number of products in the current market */
    uint64_t nJM_max;    /* largest nJM over all markets             */
    uint64_t mStart;     /* first row of the current market          */
    uint64_t mEnd;       /* one past its last row                    */

    double  *sharesM;
    mxArray *expmvaloldMx;
    double  *expmvaloldM;
    mxArray *expmuMx;
    double  *expmuM;
    mxArray *exputilMx;
    double  *exputilM;

    double oneLessRho = (1 - rho);
    double invOneLessRho = 1 / oneLessRho;

    /* With no markets there is nothing to solve, and cdindex[0] does not exist. */
    if (nM == 0) {
        return;
    }

    /* Size the scratch buffers for the largest market so they can be reused. */
    nJM_max = cdindex[0];
    for (mkt = 1; mkt < nM; mkt++) {
        nJM = cdindex[mkt] - cdindex[mkt - 1];
        if (nJM > nJM_max) nJM_max = nJM;
    }

    expmvaloldMx = mxCreateDoubleMatrix(nJM_max, 1, mxREAL);
    expmvaloldM  = mxGetPr(expmvaloldMx);

    expmuMx = mxCreateDoubleMatrix(nJM_max * nI, 1, mxREAL);
    expmuM  = mxGetPr(expmuMx);

    exputilMx = mxCreateDoubleMatrix(nJM_max * nI, 1, mxREAL);
    exputilM  = mxGetPr(exputilMx);

    for (mkt = 0; mkt < nM; mkt++) {
        mStart = (mkt == 0) ? 0 : (uint64_t) cdindex[mkt - 1];
        mEnd   = cdindex[mkt];
        nJM    = (mkt == 0) ? (uint64_t) cdindex[mkt]
                            : (uint64_t) (cdindex[mkt] - cdindex[mkt - 1]);

        sharesM = shares + mStart;

        /* Work on a private copy: squarem() overwrites its starting vector. */
        for (j = mStart; j < mEnd; j++) {
            expmvaloldM[j - mStart] = expmvalold[j];
        }

        /* Repack this market's rows of expmu into a dense nJM x nI block. */
        for (i = 0; i < nI; i++) {
            for (j = mStart; j < mEnd; j++) {
                expmuM[nJM * i + j - mStart] = expmu[nJ * i + j];
            }
        }

        results = squarem(tol, maxiter, expmvaloldM, nJM, mktShare,
                          sharesM, exputilM, expmuM, nJM, nI, oneLessRho, invOneLessRho);

        /* The solver's vector is exponentiated on the way out. */
        for (j = mStart; j < mEnd; j++) {
            expmval[j] = exp(expmvaloldM[j - mStart]);
        }

        mktnorm[mkt] = results.norm;
        mktiter[mkt] = results.iter;
    }

    mxDestroyArray(expmvaloldMx);
    mxDestroyArray(expmuMx);
    mxDestroyArray(exputilMx);
}

/*
 * mktShare: fixed-point map for one market, in the form squarem() expects.
 *
 * pSharesM     output vector (length nJM); also used as the accumulator for
 *              the predicted shares while the map is being evaluated
 * expmvaloldM  current iterate (length nJM)
 * args         remaining inputs, read in this order:
 *                double  *  observed shares           (length nJM)
 *                double  *  scratch utility buffer    (length nJM * nI)
 *                double  *  individual deviations     (nJM x nI, column-major)
 *                uint64_t   nJM
 *                uint64_t   nI
 *                double     oneLessRho
 *                double     invOneLessRho
 *
 * For each draw i, with u_k = exp(invOneLessRho * (expmvaloldM[k] + mu[i][k]))
 * and D = sum_k u_k, product k receives
 *     u_k / D * D^oneLessRho / (1 + D^oneLessRho).
 * These are averaged over the nI draws to give the predicted share s_k, and
 * the result written to pSharesM[k] is
 *     expmvaloldM[k] + oneLessRho * log(observed_k / s_k).
 */
void mktShare(double *pSharesM, double *expmvaloldM, va_list args)
{
    double *obsShares    = va_arg(args, double *);
    double *utilBuf      = va_arg(args, double *);
    double *muM          = va_arg(args, double *);
    uint64_t nProd       = va_arg(args, uint64_t);
    uint64_t nDraws      = va_arg(args, uint64_t);
    double oneLessRho    = va_arg(args, double);
    double invOneLessRho = va_arg(args, double);

    uint64_t draw, p, base;
    double total, totalPow, weight;

    /* start the share accumulator from zero */
    for (p = 0; p < nProd; p++) {
        pSharesM[p] = 0;
    }

    for (draw = 0; draw < nDraws; draw++) {
        base = draw * nProd;   /* offset of this draw's column */

        /* exponentiated utilities for this draw and their sum */
        total = 0;
        for (p = 0; p < nProd; p++) {
            utilBuf[base + p] = exp(invOneLessRho * (expmvaloldM[p] + muM[base + p]));
            total += utilBuf[base + p];
        }

        /* common factor shared by every product for this draw */
        totalPow = pow(total, oneLessRho);
        weight = (1 / total) * (totalPow / (1 + totalPow));

        for (p = 0; p < nProd; p++) {
            pSharesM[p] += utilBuf[base + p] * weight;
        }
    }

    /* average over draws, then apply the update to the current iterate */
    for (p = 0; p < nProd; p++) {
        pSharesM[p] /= (double) nDraws;
        pSharesM[p] = expmvaloldM[p] + oneLessRho * log(obsShares[p] / pSharesM[p]);
    }
}

/***********************************************************************
 *                               squarem                               *
 ***********************************************************************/

/*
 * squarem: squared-extrapolation acceleration of the fixed-point iteration
 * x <- f(x).
 *
 * Parameters
 *   tol      stop once the mean absolute change in x over one cycle is <= tol
 *   maxiter  maximum number of acceleration cycles
 *   x        starting vector of length K; overwritten with the final iterate
 *   K        length of x
 *   f        fixed-point map, called as f(out, in, args) where args is a
 *            va_list positioned at the first variadic argument
 *   ...      extra arguments, forwarded to f on every call
 *
 * Returns
 *   .norm  mean absolute change of x in the last completed cycle (1.0 if no
 *          cycle ran; NaN if the work buffers could not be allocated)
 *   .iter  number of cycles actually executed (never more than maxiter)
 *
 * Each cycle evaluates f three times: twice to build the extrapolation
 * direction and once more to stabilise the extrapolated point.
 */
struct squaremResults squarem(double tol, uint64_t maxiter, double *x, uint64_t K, void f(), ...)
{
    struct squaremResults results;
    va_list args;
    double alpha, sr2, sv2;
    double diff, norm, stepmax, stepmin, mstep;
    double NaN = mxGetNaN();
    uint64_t k, iter = 0;

    double *x1 = calloc(K, sizeof *x1);
    double *x2 = calloc(K, sizeof *x2);
    double *q1 = calloc(K, sizeof *q1);
    double *q2 = calloc(K, sizeof *q2);

    /* calloc(0, ...) may legitimately return NULL, so only K > 0 is a failure. */
    if ( (K > 0) && (x1 == NULL || x2 == NULL || q1 == NULL || q2 == NULL) ) {
        free(x1);
        free(x2);
        free(q1);
        free(q2);

        results.norm = NaN;   /* flags the failure to the caller; x is untouched */
        results.iter = 0;
        return(results);
    }

    norm    = 1.0;
    stepmax = 1.0;
    stepmin = 1.0;
    mstep   = 4.0;

    /* The counter is bumped inside the body so it counts completed cycles only. */
    while ( (norm > tol) && (iter < maxiter) ) {
        iter++;

        /* The va_list is consumed by f, so it is re-initialised for every call. */
        va_start(args, f); f(x1, x,  args); va_end(args);
        va_start(args, f); f(x2, x1, args); va_end(args);

        /* q1 = first difference, q2 - q1 = second difference */
        sr2 = 0;
        sv2 = 0;
        for (k = 0; k < K; k++) {
            q1[k] = x1[k] - x[k];
            q2[k] = x2[k] - x1[k];
            diff  = q2[k] - q1[k];
            sr2  += q1[k] * q1[k];
            sv2  += diff * diff;
        }

        /*
         * Step length clamped to [stepmin, stepmax].  If sr2/sv2 is NaN
         * (e.g. 0/0) the comparisons inside PWMIN are false and alpha
         * falls back to stepmax.
         */
        alpha = PWMAX(stepmin, PWMIN(stepmax, pow(sr2/sv2, 0.5)));

        /* extrapolated point, stored in q1 */
        for (k = 0; k < K; k++) {
            q1[k] = x[k] + 2 * alpha * q1[k] + alpha * alpha * (q2[k] - q1[k]);
        }

        /* stabilising evaluation; fall back to the plain iterate x2 on NaN */
        va_start(args, f); f(x1, q1, args); va_end(args);
        if ( isany(x1, K, NaN) ) {
            for (k = 0; k < K; k++) {
                x1[k] = x2[k];
            }
        }

        /* widen the step bound whenever it was binding */
        if (alpha == stepmax) {
            stepmax *= mstep;
        }

        /*
         * Only relevant for a negative lower bound; with stepmin starting at
         * 1.0 alpha is never negative, so this branch does not fire.
         */
        if ( (alpha == stepmin) && (alpha < 0) ) {
            stepmin *= mstep;
        }

        /* mean absolute change, and accept the new iterate */
        norm = 0;
        for (k = 0; k < K; k++) {
            norm += fabs(x1[k] - x[k]);
            x[k]  = x1[k];
        }
        norm /= (double) K;
    }

    free(x1);
    free(x2);
    free(q1);
    free(q2);

    results.norm = norm;
    results.iter = iter;

    return(results);
}

/*
 * isany: report whether any of the first K elements of x matches v.
 *
 * x  array to scan (not modified)
 * K  number of elements to inspect; K == 0 yields false
 * v  value to look for.  If v is NaN the function looks for NaN elements;
 *    an ordinary `==` test could never succeed in that case because NaN
 *    compares unequal to everything, itself included.
 *
 * Returns true on the first match, false if there is none.
 */
bool isany(double *x, uint64_t K, double v)
{
    const bool wantNaN = (v != v);   /* true only when v is NaN */
    uint64_t k;

    for (k = 0; k < K; k++) {
        if ( wantNaN ? (x[k] != x[k]) : (x[k] == v) ) {
            return(true);
        }
    }
    return(false);
}
